{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c69d1f48d21cd2b4",
   "metadata": {},
   "source": [
    "# FlashRank Reranker\n",
    "\n",
    "- Author: [Hwayoung Cha](https://github.com/forwardyoung)\n",
    "- Peer Review: []()\n",
    "- Proofread : [BokyungisaGod](https://github.com/BokyungisaGod)\n",
    "- This is a part of [LangChain Open Tutorial](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial)\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/11-Reranker/04-FlashRank-Reranker.ipynb)[![Open in GitHub](https://img.shields.io/badge/Open%20in%20GitHub-181717?style=flat-square&logo=github&logoColor=white)](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/11-Reranker/04-FlashRank-Reranker.ipynb)\n",
    "## Overview\n",
    "\n",
    "> [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) is an ultra-lightweight and ultra-fast Python library designed to add reranking to existing search and **retrieval** pipelines. It is based on state-of-the-art (**SoTA**) **cross-encoders**.\n",
    "\n",
    "This notebook introduces the use of ```FlashRank-Reranker``` within the LangChain framework, showcasing how to apply reranking techniques to improve the quality of search or **retrieval** results. It provides practical code examples and explanations for integrating ```FlashRank``` into a LangChain pipeline, highlighting its efficiency and effectiveness. The focus is on leveraging ```FlashRank```'s capabilities to enhance the ranking of outputs in a streamlined and scalable way.\n",
    "\n",
    "### Table of Contents\n",
    "\n",
    "- [Overview](#overview)\n",
    "- [Environement Setup](#environment-setup)\n",
    "- [FlashRankRerank](#flashrankrerank)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7431102d93a694f",
   "metadata": {},
   "source": [
    "## Environment Setup\n",
    "\n",
    "Set up the environment. You may refer to [Environment Setup](https://wikidocs.net/257836) for more details.\n",
    "\n",
    "**[Note]**\n",
    "- ```langchain-opentutorial``` is a package that provides a set of easy-to-use environment setup, useful functions and utilities for tutorials. \n",
    "- You can checkout the [```langchain-opentutorial```](https://github.com/LangChain-OpenTutorial/langchain-opentutorial-pypi) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "687b4939",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "%pip install langchain-opentutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af16502c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages (everything imported later in this notebook)\n",
    "from langchain_opentutorial import package\n",
    "\n",
    "package.install(\n",
    "    [\n",
    "        \"langchain\",\n",
    "        \"langchain-community\",\n",
    "        \"langchain-openai\",\n",
    "        \"langchain-text-splitters\",\n",
    "        \"faiss-cpu\",\n",
    "        \"flashrank\",\n",
    "        \"python-dotenv\",\n",
    "    ],\n",
    "    verbose=False,\n",
    "    upgrade=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "501e9dfa010f326a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set environment variables\n",
    "from langchain_opentutorial import set_env\n",
    "\n",
    "set_env(\n",
    "    {\n",
    "        \"OPENAI_API_KEY\": \"\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d83ee066d91fb4f",
   "metadata": {},
   "source": [
    "Alternatively, you can set ```OPENAI_API_KEY``` in a ```.env``` file and load it.\n",
    "\n",
    "**[Note]** This is not necessary if you've already set ```OPENAI_API_KEY``` in the previous step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abed94e9253ec29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration file to manage API keys as environment variables\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Load API key information\n",
    "load_dotenv(override=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43856bcf1e8f0c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_print_docs(docs):\n",
    "    print(\n",
    "        f\"\\n{'-' * 100}\\n\".join(\n",
    "            [\n",
    "                f\"Document {i+1}:\\n\\n{d.page_content}\\nMetadata: {d.metadata}\"\n",
    "                for i, d in enumerate(docs)\n",
    "            ]\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c7d03faa97bf809",
   "metadata": {},
   "source": [
    "## FlashrankRerank\n",
    "\n",
    "Load data for a simple example and create a **retriever**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d934121fd476be",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import TextLoader\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "# Load the documents\n",
    "documents = TextLoader(\"./data/appendix-keywords.txt\").load()\n",
    "\n",
    "# Initialized the text splitter\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
    "\n",
    "# Split the documents\n",
    "texts = text_splitter.split_documents(documents)\n",
    "\n",
    "# Add a unique ID to each text\n",
    "for idx, text in enumerate(texts):\n",
    "    text.metadata[\"id\"] = idx\n",
    "    \n",
    "# Initialize the retriever\n",
    "retriever = FAISS.from_documents(\n",
    "    texts, OpenAIEmbeddings()\n",
    ").as_retriever(search_kwargs={\"k\": 10})\n",
    "\n",
    "# query\n",
    "query = \"Tell me about Word2Vec\"\n",
    "\n",
    "# Search for documents\n",
    "docs = retriever.invoke(query)\n",
    "\n",
    "# Print the document\n",
    "pretty_print_docs(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea07e244c9171d26",
   "metadata": {},
   "source": [
    "Now, let's wrap the base **retriever** with a ```ContextualCompressionRetriever``` and use ```FlashrankRerank``` as the compressor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a21f9f025132c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import FlashrankRerank\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Initialize the LLM\n",
    "llm = ChatOpenAI(temperature=0)\n",
    "\n",
    "# Initialize the FlshrankRerank\n",
    "compressor = FlashrankRerank(model=\"ms-marco-MultiBERT-L-12\")\n",
    "\n",
    "# Initialize the ContextualCompressioinRetriever\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "# Search for compressed documents\n",
    "compressed_docs = compression_retriever.invoke(\n",
    "    \"Tell me about Word2Vec.\"\n",
    ")\n",
    "\n",
    "# Print the document ID\n",
    "print([doc.metadata[\"id\"] for doc in compressed_docs])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a147fc787860bac",
   "metadata": {},
   "source": [
    "Compare the results after reranker is applied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "732f27a4e8b3d4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the results of document compressions\n",
    "pretty_print_docs(compressed_docs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-opentutorial-GHgbjDj7-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
